package com.xz.tree;

import java.util.Comparator;
import java.util.Map;

/**
 * @author XZ
 * @create 2021-06-23 20:32
 */
public class AVLTree<E> extends BST<E>{
    public AVLTree() {
        this(null);
    }

    public AVLTree(Comparator<E> comparator) {
        super(comparator);
    }

    @Override
    protected void afterAdd(Node<E> node) {
        while ((node = node.parent) != null) {
            if (isBalance(node)) {
                //更新高度
                updateHeight(node);

            } else {
                //恢复平衡
                rebalance(node);
                //整个树都恢复平衡了。
                break;
            }
        }
    }

    @Override
    protected void afterRemove(Node<E> node) {
        while ((node = node.parent) != null) {
            if (isBalance(node)) {
                //更新高度
                updateHeight(node);
            } else {
                //恢复平衡
                rebalance(node);
            }
        }
    }

    @Override
    protected Node<E> createNode(E element, Node<E> parent) {
        return new AVLNode<>(element, parent);
    }
    /**
     * 恢复平衡
     * @param grand：高度最低都不平衡节点
     */
    private void rebalance(Node<E> grand) {
        Node<E> parent = ((AVLNode<E>) grand).tallerChild();
        Node<E> node = ((AVLNode<E>) parent).tallerChild();
        if (parent.isLeftChild()) {
            if (node.isLeftChild()) {//LL
                rotateRight(grand);
            } else {//LR
                rotateLeft(parent);
                rotateRight(grand);
            }
        } else {
            if (node.isLeftChild()) {//RL
                rotateRight(parent);
                rotateLeft(grand);
            } else {//RR
                rotateLeft(grand);
            }
        }
    }

    //错误
    private void rebalance2(Node<E> grand) {
        Node<E> parent = ((AVLNode<E>) grand).tallerChild();
        Node<E> node = ((AVLNode<E>) parent).tallerChild();
        if (parent.isLeftChild()) {
            if (node.isLeftChild()) {//LL
                rotate(grand, node.left, node, node.right, parent, parent.right, grand, grand.right);
            } else {//LR
                rotate(grand,parent.left,parent,node.left,node,node.right,grand,grand.right);
            }
        } else {
            if (node.isLeftChild()) {//RL
                rotate(grand,grand.left,grand,node.left,node.right,parent,parent.right,parent.right);
            } else {//RR
                rotate(grand,grand.left,grand,parent.left,parent,node.left,node,node.right);
            }
        }
    }

    private void rotate(Node<E> r,
                        Node<E> a,Node<E> b,Node<E> c,
                        Node<E> d,
                        Node<E> e,Node<E> f,Node<E> g) {

        d.parent = r.parent;
        //让d成为这棵子树的根节点
        if (r.isLeftChild()) {
            r.parent.left=d;
        } else if (r.isRightChild()) {
            r.parent.right = d;
        } else {
            root = d;
        }

        //a-b-c
        b.left=a;
        if (a != null) {
            a.parent=b;
        }
        b.right=c;
        if (c != null) {
            c.parent=b;
        }
        updateHeight(b);

        //e-f-g
        f.left=e;
        if (e != null) {
            e.parent=f;
        }
        f.right=g;
        if (g!= null) {
            g.parent=f;
        }
        updateHeight(f);

        //b-d-f
        d.left=b;
        d.right=f;
        b.parent=d;
        f.parent=d;
        updateHeight(d);


    }

    private void rotateLeft(Node<E> grand) {
        Node<E> parent = grand.right;
        Node<E> child = parent.left;

        grand.right = parent.left;
        parent.left = grand;
        afterRotate(grand, parent, child);
    }



    private void rotateRight(Node<E> grand) {
        Node<E> parent =grand.left;
        Node<E> child = parent.right;
        grand.left = child;
        parent.right = grand;

        afterRotate(grand, parent, child);
    }
    private void afterRotate(Node<E> grand, Node<E> parent, Node<E> child) {

        //让parent成为子树的根节点
        parent.parent = grand.parent;
        if (grand.isLeftChild()) {
            grand.parent.left = parent;
        } else if (grand.isRightChild()) {
            grand.parent.right = parent;
        } else {//grand是root节点
            root = parent;
        }

        //更新child的parent
        if (child != null) {
            child.parent=grand;
        }

        //更新grand的parent
        grand.parent = parent;

        //更新高度
        updateHeight(grand);
        updateHeight(parent);
    }
    private boolean isBalance(Node<E> node) {
        return Math.abs(((AVLNode<E>) node).balanceFactor()) <= 1;
    }

    private void updateHeight(Node<E> node) {
        ((AVLNode<E>) node).updateHeight();
    }

    private class AVLNode<E> extends Node {
        int height=1;

        public AVLNode(Object element, Node parent) {
            super(element, parent);
        }

        public int balanceFactor() {
            int leftHeight = left == null ? 0 : ((AVLNode<E>)left).height;
            int rightHeight = right == null ? 0 : ((AVLNode<E>)right).height;
            return leftHeight - rightHeight;
        }

        public void updateHeight() {
            int leftHeight = left == null ? 0 : ((AVLNode<E>)left).height;
            int rightHeight = right == null ? 0 : ((AVLNode<E>)right).height;
            height = 1 + Math.max(leftHeight, rightHeight);
        }

        public Node<E> tallerChild() {
            int leftHeight = left == null ? 0 : ((AVLNode<E>)left).height;
            int rightHeight = right == null ? 0 : ((AVLNode<E>)right).height;
            if (leftHeight>rightHeight) return left;
            if (leftHeight<rightHeight) return right;
            return isLeftChild() ? left : right;
        }
    }
}
